//
// Created by Gao Shihao on 2023/8/31.
//

#ifndef PLC2LLVM_BITSTRTYPE_H
#define PLC2LLVM_BITSTRTYPE_H
#include <plc2llvm/TypeSystem/BasicType.h>
#include "plc2llvm/TypeSystem/TypeMachine.h"
#include "plc2llvm/ScopeSystem/ScopeManager.h"

namespace plcst {


    /// Bit-string elementary type, parameterised by its TypeKind tag and bit width.
    ///
    /// @tparam Kind  value returned by getTypeKind()
    /// @tparam size  value returned by getNBits()
    template<TypeKind Kind, int size>
    class BitStrType : public BasicType {
    public:

        [[nodiscard]] TypeKind getTypeKind() const override;

        /// Construct a named bit-string type.
        /// @param name  type name, moved into BasicType
        /// @param i     initial value stored with the type (defaults to 0)
        explicit BitStrType(std::string&& name, i64 i = 0) : BasicType(std::move(name)), initValue(i) {}

        /// Duplicate another instance of the same instantiation (used by clone()).
        /// Both the BasicType part and the stored initial value are copied.
        explicit BitStrType(BitStrType<Kind, size>* another)
            : BasicType(another), initValue(another->initValue) {}

        /// Result type of `this <op> another`.
        /// @return {BOOL (looked up in the global scope), 0} for EQUAL / NOTEQUAL;
        ///         getLargerType(another) for AND / XOR / OR;
        ///         {nullptr, 2} when `another` is not a bit string or `op` is any other operator.
        TypeMsg typeSynthesisForBinaryOperator(ExpressionOperator op, std::shared_ptr<Type> another) override {
            // Bind by reference, as the other members do, instead of copying the machine.
            TypeMachine& typeMachine = TypeMachine::getTypeMachine();
            if (!typeMachine.isBitstr(another->getTypeKind())) {
                return {nullptr, 2};
            }

            switch (op) {
                case ExpressionOperator::EQUAL:
                case ExpressionOperator::NOTEQUAL: {
                    auto boolType = ScopeManager::getScopeManager().getGlobalScope()->find<Type>("BOOL");
                    return {boolType, 0};
                }
                case ExpressionOperator::AND:
                case ExpressionOperator::XOR:
                case ExpressionOperator::OR:
                    return this->getLargerType(another);
                default:
                    return {nullptr, 2};
            }
        }

        /// Result type of `<op> this`.
        /// @return {this type as registered in the global scope, 0} for NOT;
        ///         {nullptr, 2} for every other operator.
        TypeMsg typeSynthesisForUnaryOperator(ExpressionOperator op) override {
            if (op != ExpressionOperator::NOT) {
                return {nullptr, 2};
            }
            // The shared instance is fetched by name from the global scope because
            // this object only has a raw `this`, not an owning shared_ptr.
            auto thisShared = ScopeManager::getScopeManager().getGlobalScope()->find<Type>(this->getTypeName());
            return {thisShared, 0};
        }

        /// @return the bit width fixed by the `size` template argument.
        [[nodiscard]] int getNBits() const override {
            return nbits;
        }

        /// Pick the wider of this type and `another`.
        /// @return {this type, 0} if this is strictly wider; {another, 0} if `another` is at
        ///         least as wide; {nullptr, 2} if `another` is not a bit string or does not
        ///         expose a bit width through BasicType.
        TypeMsg getLargerType(std::shared_ptr<Type> another) override {
            TypeMachine& typeMachine = TypeMachine::getTypeMachine();
            if (!typeMachine.isBitstr(another->getTypeKind())) {
                return {nullptr, 2};
            }

            auto anotherBasic = std::dynamic_pointer_cast<BasicType>(another);
            if (!anotherBasic) {
                // Guard: a bit-string kind that is not a BasicType cannot report its width.
                return {nullptr, 2};
            }

            if (this->getNBits() > anotherBasic->getNBits()) {
                auto thisShared = ScopeManager::getScopeManager().getGlobalScope()->find<Type>(this->getTypeName());
                return {thisShared, 0};
            }
            return {another, 0};
        }

        /// Emit IR converting `src` (a value of this type) to `destTy`.
        /// @return `src` unchanged when the kinds match; otherwise the converted value.
        /// @throws SemanticError if `destTy` is not a bit-string type.
        virtual llvm::Value* castTo(std::shared_ptr<Type> destTy, llvm::Value* src, llvm::IRBuilder<>* builder) {
            const auto destKind = destTy->getTypeKind();
            if (this->getTypeKind() == destKind) {
                return src;
            }

            TypeMachine& typeMachine = TypeMachine::getTypeMachine();
            if (!typeMachine.isBitstr(destKind)) {
                throw SemanticError("bad bit str cast");
            }

            auto destLLVMTy = destTy->llvmty;
            if (src->getType()->isIntOrIntVectorTy() && destLLVMTy->isIntOrIntVectorTy()) {
                // LLVM `bitcast` requires equal bit sizes, so bit strings of different
                // widths must be zero-extended (widening) or truncated (narrowing).
                // CreateZExtOrTrunc returns `src` itself when the widths already match.
                return builder->CreateZExtOrTrunc(src, destLLVMTy);
            }
            // Non-integer representations keep the original same-size reinterpretation.
            return builder->CreateBitCast(src, destLLVMTy);
        }

        /// @return a new shared instance copied from this one (including its initial value).
        std::shared_ptr<Type> clone() override {
            return std::make_shared<BitStrType<Kind, size>>(this);
        }

    private:
        // Initial value supplied at construction; copied by the duplicating constructor.
        i64 initValue{0};
        // Compile-time constants derived from the template arguments.
        static constexpr int nbits = size;
        static constexpr TypeKind kind = Kind;
    };

    /// Report the TypeKind tag this instantiation was declared with.
    /// No per-instance state is consulted: the tag is the `Kind` template argument itself.
    template<TypeKind Kind, int size>
    TypeKind BitStrType<Kind, size>::getTypeKind() const {
        return Kind;
    }

}
#endif //PLC2LLVM_BITSTRTYPE_H
